{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3",
   "language": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch \n",
    "from model import Model\n",
    "import pandas as pd \n",
    "import numpy as np\n",
    "from torch.utils.data import Dataset,DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class data(Dataset):\n",
    "\n",
    "    def __init__(self,path):\n",
    "        self.dataset = np.array(pd.read_csv(path))\n",
    "        self.map = {'draw': 1, 'erase': 0}\n",
    "\n",
    "    def __getitem__(self,index):\n",
    "        yvals = torch.tensor([self.map[self.dataset[index][1]]], dtype=torch.float)\n",
    "        xvals = torch.from_numpy(self.dataset[index][2:].astype(np.float32))\n",
    "        return xvals,yvals\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.dataset.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 5\n",
    "dataset = data('/Users/gursi/Desktop/virtual_drawing_board/dataset/dataset.csv')\n",
    "loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "EPOCH :  0\n",
      "  |  Loss :  76.07135772705078\n",
      "EPOCH :  1\n",
      "  |  Loss :  70.83623504638672\n",
      "EPOCH :  2\n",
      "  |  Loss :  62.52872085571289\n",
      "EPOCH :  3\n",
      "  |  Loss :  62.942626953125\n",
      "EPOCH :  4\n",
      "  |  Loss :  62.76849365234375\n",
      "EPOCH :  5\n",
      "  |  Loss :  69.89595031738281\n",
      "EPOCH :  6\n",
      "  |  Loss :  47.753761291503906\n",
      "EPOCH :  7\n",
      "  |  Loss :  55.33074188232422\n",
      "EPOCH :  8\n",
      "  |  Loss :  81.68561553955078\n",
      "EPOCH :  9\n",
      "  |  Loss :  52.624629974365234\n",
      "EPOCH :  10\n",
      "  |  Loss :  46.77412796020508\n",
      "EPOCH :  11\n",
      "  |  Loss :  32.127349853515625\n",
      "EPOCH :  12\n",
      "  |  Loss :  35.55134582519531\n",
      "EPOCH :  13\n",
      "  |  Loss :  33.870018005371094\n",
      "EPOCH :  14\n"
     ]
    },
    {
     "output_type": "error",
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-4-970039a2fe01>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     15\u001b[0m         \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcrit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0myhat\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     16\u001b[0m         \u001b[0mtotal_loss\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 17\u001b[0;31m         \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     18\u001b[0m         \u001b[0mopt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     19\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/opt/anaconda3/envs/asl/lib/python3.8/site-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[1;32m    219\u001b[0m                 \u001b[0mretain_graph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    220\u001b[0m                 create_graph=create_graph)\n\u001b[0;32m--> 221\u001b[0;31m         \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    222\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    223\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/opt/anaconda3/envs/asl/lib/python3.8/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[1;32m    128\u001b[0m         \u001b[0mretain_graph\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    129\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 130\u001b[0;31m     Variable._execution_engine.run_backward(\n\u001b[0m\u001b[1;32m    131\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    132\u001b[0m         allow_unreachable=True)  # allow_unreachable flag\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "epochs = 30 \n",
    "lr = 0.001\n",
    "model = Model()\n",
    "crit = torch.nn.BCELoss()\n",
    "opt = torch.optim.Adam(model.parameters(), lr=lr)\n",
    "loss_list = []\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    print('EPOCH : ', epoch)\n",
    "    total_loss = 0\n",
    "    \n",
    "    for x,y in loader : \n",
    "        opt.zero_grad()\n",
    "        yhat = model.forward(x)\n",
    "        loss = crit(yhat,y)\n",
    "        total_loss += loss\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "\n",
    "    print('  |  Loss : ',total_loss.item(), end='')\n",
    "    print()\n",
    "    loss_list.append(total_loss.item())"
   ]
  }
 ]
}